#!/usr/bin/env python3
'''
Author: Paschalis Agapitos
Project: Mestizajes

Provides a comprehensive embeddings toolkit for preparing and analyzing
Wikipedia link representations, including title normalization, corpus
grouping, vocabulary construction, vector encoding, TF-IDF weighted pooling,
FAISS-based similarity search, and interactive nearest-neighbor visualization
for temporal and person-level embedding studies.
'''
from __future__ import annotations

import re
import gc
from collections import Counter
from pathlib import Path
from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Tuple

import faiss
import numpy as np
from optree import all_leaves
import pandas as pd
from plotly import graph_objects as go
from sentence_transformers import SentenceTransformer


class EmbeddingsToolkit:
    """
    Utilities for cleaning links, building vocabularies, encoding strings, pooling
    TF–IDF group embeddings, running FAISS similarity, and plotting nearest neighbors.
    """

    # ---------- Cleaning & grouping ----------

    @staticmethod
    def canonicalize_titles(title: str) -> str:
        """Normalize titles/links: strip and collapse whitespace/underscores."""
        return re.sub(r"[\s_]+", " ", str(title).strip())

    @staticmethod
    def build_clean_links(scientists_as_source: pd.DataFrame) -> pd.DataFrame:
        """Return a copy with normalized 'Target' strings."""
        df = scientists_as_source.copy()
        df["Target"] = df["Target"].astype(str).str.strip()
        df["Target"] = df["Target"].map(EmbeddingsToolkit.canonicalize_titles)
        return df

    @staticmethod
    def group_links_by_century(df: pd.DataFrame) -> Dict[int, List[str]]:
        """Map century_of_birth -> list of canonicalized links (repeats retained)."""
        grouped: Dict[int, List[str]] = (
            df.groupby("century_of_birth")["Target"].apply(list).to_dict()
        )
        return grouped

    @staticmethod
    def group_links_by_scientist(df: pd.DataFrame) -> Dict[str, List[str]]:
        """Map Source -> list of canonicalized links (repeats retained)."""
        grouped: Dict[str, List[str]] = df.groupby("Source")["Target"].apply(list).to_dict()
        return grouped

    # ---------- Vocabulary & embeddings ----------

    @staticmethod
    def compute_vocab(links_by_group: List[List[str]]) -> Tuple[List[str], Dict[str, int]]:
        """Build sorted unique vocabulary and {link -> row_index} map."""
        unique_links = sorted(set(link for lst in links_by_group for link in lst if link))
        link_to_idx = {link: idx for idx, link in enumerate(unique_links)}
        return unique_links, link_to_idx

    @staticmethod
    def embed_strings(
        strings: Iterable[str],
        model: SentenceTransformer,
        batch_size: int = 256,
        normalize: bool = True,
        device: Optional[str] = None,
        ) -> np.ndarray:
        """Encode strings into a (n, d) float32 NumPy array."""
        corpus = list(strings)
        vectors = model.encode(
            corpus,
            batch_size=batch_size,
            show_progress_bar=True,
            device=device,
            normalize_embeddings=normalize,
            convert_to_numpy=True,
            )
        return np.asarray(vectors, dtype=np.float32)
        
    @staticmethod
    def save_vectors(
        vectors: np.ndarray,
        vocab: List[str],
        filename_prefix: str,
        save_dir: Optional[Path],
    ) -> None:
        """Persist vectors (.npy, float16) and vocabulary (.csv)."""
        if save_dir is None:
            return
        save_dir.mkdir(parents=True, exist_ok=True)
        np.save(save_dir / f"{filename_prefix}_vectors.npy", vectors.astype(np.float16))
        pd.Series(vocab, name="link").to_csv(
            save_dir / f"{filename_prefix}_vocab.csv", index=False
        )

    # ---------- TF–IDF pooling ----------

    @staticmethod
    def build_df_counts(links_by_document: Mapping[str, Sequence[str]]) -> Counter:
        """Document frequency counts: number of documents that contain each link."""
        df_counts = Counter()
        for links in links_by_document.values():
            df_counts.update(set(links))
        return df_counts

    @staticmethod
    def compute_idf_lookup(
        links_by_document: Mapping[str, Sequence[str]],
        unique_links: Iterable[str],
    ) -> Dict[str, float]:
        """
        Add-one smoothed IDF:
        IDF(link) = log((N_docs + 1) / (DF(link) + 1)) + 1
        """
        num_documents = len(links_by_document)
        df_counts = EmbeddingsToolkit.build_df_counts(links_by_document)
    
        # More efficient: vectorized computation
        unique_links_list = list(unique_links)
        df_values = np.array([df_counts.get(link, 0) for link in unique_links_list])
        idf_values = np.log((num_documents + 1.0) / (df_values + 1.0)) + 1.0
    
        return dict(zip(unique_links_list, idf_values))

    @staticmethod
    def pool_vectors_tfidf(
        embed_matrix: np.ndarray,
        indices: List[int],
        term_frequencies: List[int],
        idf_values: List[float],
    ) -> np.ndarray:
        """TF–IDF weighted mean over selected rows, then L2-normalize."""
        # More efficient: direct NumPy operations
        selected_vectors = embed_matrix[indices]  # (k, dim)
        tf_idf_weights = np.array(term_frequencies, dtype=np.float32) * np.array(idf_values, dtype=np.float32)
    
        # Normalize weights
        tf_idf_weights = tf_idf_weights / (tf_idf_weights.sum() + 1e-12)
    
        # Weighted average
        pooled = np.average(selected_vectors, axis=0, weights=tf_idf_weights)
    
        # L2 normalize
        norm = np.linalg.norm(pooled) + 1e-12
        return pooled / norm

    @staticmethod
    def build_group_embeddings_tfidf(
        links_by_group: Mapping[str, Sequence[str]],
        link_to_idx: Mapping[str, int],
        embed_matrix: np.ndarray,
        idf_lookup: Mapping[str, float],
    ) -> Dict[str, np.ndarray]:
        """Return {group_key -> TF–IDF pooled vector}."""
        embeddings: Dict[str, np.ndarray] = {}

        for group_key, links in links_by_group.items():
            # More efficient: filter and process in one pass
            valid_links_data = [
                (link_to_idx[link], tf_val, float(idf_lookup.get(link, 0.0)))
                for link, tf_val in Counter(links).items()
                if link in link_to_idx
            ]
        
            if valid_links_data:
                indices, term_frequencies, idf_values = zip(*valid_links_data)
                vector = EmbeddingsToolkit.pool_vectors_tfidf(
                    embed_matrix=embed_matrix,
                    indices=list(indices),
                    term_frequencies=list(term_frequencies),
                    idf_values=list(idf_values),
                )
                embeddings[group_key] = vector

        return embeddings

    # ---------- FAISS similarity ----------

    @staticmethod
    def faiss_full_similarity_df(embeddings: Dict[str, np.ndarray]) -> pd.DataFrame:
        """
        Full (n × n) cosine similarity matrix as a DataFrame.
        Uses FAISS inner product on L2-normalized rows.
        """
        keys = list(embeddings.keys())
        matrix = np.vstack([embeddings[k] for k in keys]).astype(np.float32)
        faiss.normalize_L2(matrix)

        index = faiss.IndexFlatIP(matrix.shape[1])
        index.add(matrix)

        num_items = matrix.shape[0]
        sims, nbr_indices = index.search(matrix, num_items)

        dense = np.zeros((num_items, num_items), dtype=np.float32)
        row_ids = np.arange(num_items)[:, None]
        dense[row_ids, nbr_indices] = sims

        # Cleanup
        del matrix, index, sims, nbr_indices, row_ids
        gc.collect()

        return pd.DataFrame(dense, index=keys, columns=keys)

    @staticmethod
    def faiss_topk_neighbors(
        query_embeddings: Dict[str, np.ndarray],
        index_embeddings: Dict[str, np.ndarray],
        k: int = 10,
    ) -> Dict[str, List[Tuple[str, float]]]:
        """
        Top-k cosine neighbors for each query against the index.
        Returns {query_id -> [(neighbor_id, similarity), ...]}.
        """
        query_keys = list(query_embeddings)
        index_keys = list(index_embeddings)

        queries = np.vstack([query_embeddings[k] for k in query_keys]).astype(np.float32)
        index_matrix = np.vstack([index_embeddings[k] for k in index_keys]).astype(np.float32)

        faiss.normalize_L2(queries)
        faiss.normalize_L2(index_matrix)

        index = faiss.IndexFlatIP(index_matrix.shape[1])
        index.add(index_matrix)

        k = min(k, len(index_keys))
        sims, nbrs = index.search(queries, k)

        result = {
            qk: [(index_keys[nbrs[row, col]], float(sims[row, col])) for col in range(k)]
            for row, qk in enumerate(query_keys)
        }

        # Cleanup
        del queries, index_matrix, index, sims, nbrs, query_keys, index_keys
        gc.collect()

        return result

    @staticmethod
    def get_period(century: int) -> str:
        """Simple historical period mapping used for coloring."""
        if century in {-5, -4, -3, -2, -1, 1, 2, 3, 4}:
            return "Classical Antiquity"
        if century in {5, 6, 7, 8, 9, 10, 11, 12, 13, 14}:
            return "Middle Ages"
        if century in {15, 16, 17}:
            return "Renaissance"
        if century == 18:
            return "Enlightenment"
        if century in {19, 20}:
            return "Modern Era"
        return "Other"

    @staticmethod
    def plot_with_neighbors(
        focus_id: str,
        embeddings_3d: np.ndarray,
        all_labels: List[str],
        dataframe: pd.DataFrame,
        topk: Dict[str, List[Tuple[str, float]]],
        max_neighbors: int = 10,
        figure_width: int = 1000,
        figure_height: int = 800,
        font_size: int = 14,
        title_font_size: int = 16,
        show_neighbor_names: bool = True,
        max_labeled_neighbors: int = 8,
    ) -> None:
        """
        Show a 3D scatter of all points with one focus point and its nearest
        neighbors highlighted.

        Traces are added in this order: faint background points, the focus
        marker, one marker per neighbor, connection lines from the focus to
        the five best-ranked neighbors, and one dummy trace per historical
        period so that every period appears in the legend.

        Args:
            focus_id: Label of the point to highlight; must be in ``all_labels``.
            embeddings_3d: (n, 3) coordinates.
            all_labels: Names aligned with the rows of ``embeddings_3d``.
            dataframe: Must contain 'Source' and 'century_of_birth'; the first
                century found per source is used for hover text and coloring.
            topk: {label -> [(neighbor, similarity), ...]}.
            max_neighbors: Maximum number of neighbors drawn.
            figure_width: Figure width in pixels.
            figure_height: Figure height in pixels.
            font_size: Base font size.
            title_font_size: Title font size.
            show_neighbor_names: Whether neighbor markers carry text labels.
            max_labeled_neighbors: Only neighbors ranked up to this value are
                labelled.

        Raises:
            KeyError: If ``focus_id`` is not present in ``all_labels``.
        """
        name_to_index = {name: idx for idx, name in enumerate(all_labels)}
        if focus_id not in name_to_index:
            raise KeyError(f"focus_id {focus_id!r} is not present in all_labels")

        # Color-blind safe palette
        period_colors = {
            "Classical Antiquity": "#4477AA",  # Blue
            "Middle Ages": "#66CCEE",          # Cyan
            "Renaissance": "#228833",          # Green
            "Enlightenment": "#CCBB44",        # Yellow
            "Modern Era": "#EE6677",           # Red
            "Other": "#BBBBBB",                # Gray
        }

        # Marker shapes for redundancy
        period_symbols = {
            "Classical Antiquity": "circle",
            "Middle Ages": "circle-open",
            "Renaissance": "diamond",
            "Enlightenment": "x",
            "Modern Era": "square",
            "Other": "cross",
        }

        background_color = "#F8FAFB"
        text_color = "#1A365D"
        grid_color = "#E2E8F0"
        highlight_color = "#2D3748"

        scientist_centuries = dataframe.groupby(
            "Source")["century_of_birth"].first().to_dict()

        def describe_century(name: str) -> Tuple[object, str]:
            """Return (century shown in hover text, period label) for a name."""
            raw = scientist_centuries.get(name)
            is_float = isinstance(raw, (float, np.floating))
            if raw is None or (is_float and np.isnan(raw)):
                return "Unknown", "Other"
            # A column with missing values is stored as float, so 19.0 must be
            # treated like 19 instead of falling through to "Other".
            if isinstance(raw, (int, np.integer)) or (is_float and float(raw).is_integer()):
                century = int(raw)
                return century, EmbeddingsToolkit.get_period(century)
            return raw, "Other"

        def axis_style(label: str) -> dict:
            """Shared styling for the three scene axes."""
            return dict(
                title=dict(text=label, font=dict(size=font_size, family="Arial")),
                gridcolor=grid_color,
                gridwidth=1.5,
                showbackground=True,
                backgroundcolor="rgba(247,250,252,0.8)",
            )

        fig = go.Figure()

        # Background points
        fig.add_trace(
            go.Scatter3d(
                x=embeddings_3d[:, 0],
                y=embeddings_3d[:, 1],
                z=embeddings_3d[:, 2],
                mode="markers",
                marker=dict(size=2, color="#CBD5E0",
                            opacity=0.04, symbol="circle"),
                name="Other Scientists",
                text=all_labels,
                hovertemplate="<b>%{text}</b><extra></extra>",
                showlegend=True,
            )
        )

        # Focus point
        focus_coords = embeddings_3d[name_to_index[focus_id]]
        focus_century, focus_period = describe_century(focus_id)

        fig.add_trace(
            go.Scatter3d(
                x=[focus_coords[0]],
                y=[focus_coords[1]],
                z=[focus_coords[2]],
                mode="markers+text",
                marker=dict(
                    size=18,
                    color=highlight_color,
                    opacity=1.0,
                    symbol="diamond",
                    line=dict(color="#F7FAFC", width=3),
                ),
                text=[focus_id],
                textposition="top center",
                textfont=dict(size=font_size + 2, color=text_color,
                              family="Arial Black"),
                name=f"Focus: {focus_id}",
                hovertemplate=f"<b>{focus_id}</b><br>Century: {focus_century}<br>Period: {focus_period}<extra></extra>",
                showlegend=True,
            )
        )

        # Drop the focus itself and unknown labels *before* ranking, so ranks
        # always start at 1 and max_neighbors counts only drawable neighbors.
        neighbor_list = [
            (neighbor, sim)
            for neighbor, sim in topk.get(focus_id, [])
            if neighbor != focus_id and neighbor in name_to_index
        ][:max_neighbors]

        connection_traces = []  # Added after all markers

        for rank, (neighbor, sim) in enumerate(neighbor_list, start=1):
            neighbor_coords = embeddings_3d[name_to_index[neighbor]]

            century, period = describe_century(neighbor)
            color = period_colors.get(period, "#718096")
            symbol = period_symbols.get(period, "diamond")

            labelled = show_neighbor_names and rank <= max_labeled_neighbors
            mode = "markers+text" if labelled else "markers"
            display_name = neighbor if len(
                neighbor) <= 20 else neighbor[:17] + "..."

            fig.add_trace(
                go.Scatter3d(
                    x=[neighbor_coords[0]],
                    y=[neighbor_coords[1]],
                    z=[neighbor_coords[2]],
                    mode=mode,
                    marker=dict(
                        # Marker shrinks with rank, down to a floor of 8.
                        size=max(14 - 0.8 * (rank - 1), 8),
                        color=color,
                        opacity=0.85,
                        symbol=symbol,
                        line=dict(color="#F7FAFC", width=2),
                    ),
                    text=[display_name] if mode == "markers+text" else None,
                    textposition="top center",
                    textfont=dict(size=max(font_size - 2, 10),
                                  color=color, family="Arial"),
                    name=f"#{rank}: {neighbor}" if rank <= 5 else f"{period} (#{rank})",
                    hovertemplate=(
                        f"<b>{neighbor}</b><br>Century: {century}<br>Period: {period}<br>"
                        f"Similarity: {sim:.1%}<br>Rank: #{rank}<extra></extra>"
                    ),
                    showlegend=rank <= 5,
                )
            )

            # Connection line (top 5); only the first one gets a legend entry.
            if rank <= 5:
                connection_traces.append(
                    go.Scatter3d(
                        x=[focus_coords[0], neighbor_coords[0]],
                        y=[focus_coords[1], neighbor_coords[1]],
                        z=[focus_coords[2], neighbor_coords[2]],
                        mode="lines",
                        line=dict(
                            color=f"rgba(45,55,72,{0.4 - 0.05 * (rank - 1)})",
                            width=max(3 - 0.3 * (rank - 1), 1),
                            dash="solid" if rank <= 2 else "dot",
                        ),
                        name="Similarity Connections" if rank == 1 else None,
                        showlegend=(rank == 1),
                        hoverinfo="skip",
                    )
                )

        # Add connection traces
        for trace in connection_traces:
            fig.add_trace(trace)

        # Legend-only traces: one invisible point per period.
        for position, period_name in enumerate(period_colors):
            fig.add_trace(
                go.Scatter3d(
                    x=[None],  # Invisible point
                    y=[None],
                    z=[None],
                    mode="markers",
                    marker=dict(
                        size=10,
                        color=period_colors[period_name],
                        symbol=period_symbols.get(period_name, "circle"),
                        line=dict(color="#F7FAFC", width=1),
                    ),
                    name=period_name,
                    showlegend=True,
                    legendgroup="periods",
                    # The group title is attached to the first period only.
                    legendgrouptitle_text="Historical Periods" if position == 0 else None,
                )
            )

        fig.update_layout(
            template="plotly_white",
            paper_bgcolor=background_color,
            plot_bgcolor=background_color,
            scene=dict(
                xaxis=axis_style("Dimension 1"),
                yaxis=axis_style("Dimension 2"),
                zaxis=axis_style("Dimension 3"),
                camera=dict(eye=dict(x=1.5, y=1.5, z=1.5)),
                aspectmode="cube",
            ),
            title=dict(
                text=f"Scientific Similarity Network: {focus_id}"
                + (f" (showing names for top {max_labeled_neighbors})" if show_neighbor_names else ""),
                x=0.5,
                font=dict(size=title_font_size,
                          family="Arial Black", color=text_color),
            ),
            legend=dict(
                x=0.02,
                y=0.98,
                bgcolor="rgba(255,255,255,0.95)",
                bordercolor=text_color,
                borderwidth=2,
                font=dict(size=font_size - 2, family="Arial"),
                tracegroupgap=5,  # Space between legend groups
            ),
            width=figure_width,
            height=figure_height,
            margin=dict(l=60, r=60, b=80, t=100),
            font=dict(family="Arial", size=font_size, color=text_color),
        )

        fig.show()

        # Cleanup
        del fig, name_to_index, scientist_centuries, connection_traces
        gc.collect()
